{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TTS Inference\n",
    "\n",
    "This notebook can be used to generate audio samples using either NeMo's pretrained models or after training NeMo TTS models."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# License\n",
    "\n",
    "> Copyright 2020 NVIDIA. All Rights Reserved.\n",
    "> \n",
    "> Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "> you may not use this file except in compliance with the License.\n",
    "> You may obtain a copy of the License at\n",
    "> \n",
    ">     http://www.apache.org/licenses/LICENSE-2.0\n",
    "> \n",
    "> Unless required by applicable law or agreed to in writing, software\n",
    "> distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "> See the License for the specific language governing permissions and\n",
    "> limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\"\n",
    "# # If you're using Google Colab and not running locally, uncomment and run this cell.\n",
    "# !apt-get install sox libsndfile1 ffmpeg\n",
    "# !pip install wget unidecode\n",
    "# BRANCH = 'v1.0.0'\n",
    "# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[tts]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models\n",
    "\n",
    "First we pick the models that we want to use. Currently supported models are:\n",
    "\n",
    "End-to-End Models:\n",
    "- [FastPitch_HifiGan_E2E](https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_en_e2e_fastpitchhifigan)\n",
    "- [FastSpeech2_HifiGan_E2E](https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_en_e2e_fastspeech2hifigan)\n",
    "\n",
    "Mel Spectrogram Generators:\n",
    "- [Tacotron 2](https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_en_tacotron2)\n",
    "- [Glow-TTS](https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_en_glowtts)\n",
    "- [TalkNet](https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_en_talknet)\n",
    "- [FastPitch](https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_en_fastpitch)\n",
    "- [FastSpeech2](https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_en_fastspeech_2)\n",
    "\n",
    "Audio Generators\n",
    "- [WaveGlow](https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_waveglow_88m)\n",
    "- [SqueezeWave](https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_squeezewave)\n",
    "- [UniGlow](https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_uniglow)\n",
    "- [MelGAN](https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_melgan)\n",
    "- [HiFiGAN](https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_hifigan)\n",
    "- Griffin-Lim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from ipywidgets import Select, HBox, Label\n",
    "from IPython.display import display\n",
    "\n",
    "supported_e2e = [\"fastpitch_hifigan\", \"fastspeech2_hifigan\", None]\n",
    "supported_spec_gen = [\"tacotron2\", \"glow_tts\", \"talknet\", \"fastpitch\", \"fastspeech2\", None]\n",
    "supported_audio_gen = [\"waveglow\", \"squeezewave\", \"uniglow\", \"melgan\", \"hifigan\", \"griffin-lim\", None]\n",
    "\n",
    "print(\"Select the model(s) that you want to use. Please choose either 1 end-to-end model or 1 spectrogram generator and 1 vocoder.\")\n",
    "e2e_selector = Select(options=supported_e2e, value=None)\n",
    "spectrogram_generator_selector = Select(options=supported_spec_gen, value=None)\n",
    "audio_generator_selector = Select(options=supported_audio_gen, value=None)\n",
    "display(HBox([e2e_selector, Label(\"OR\"), spectrogram_generator_selector, Label(\"+\"), audio_generator_selector]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "e2e_model = e2e_selector.value\n",
    "spectrogram_generator = spectrogram_generator_selector.value\n",
    "audio_generator = audio_generator_selector.value\n",
    "\n",
    "if e2e_model is None and spectrogram_generator is None and audio_generator is None:\n",
    "    raise ValueError(\"No models were chosen. Please return to the previous step and choose either 1 end-to-end model or 1 spectrogram generator and 1 vocoder.\")\n",
    "\n",
    "if e2e_model and (spectrogram_generator or audio_generator):\n",
    "    raise ValueError(\n",
    "        \"An end-to-end model was chosen and either a spectrogram generator or a vocoder was also selected. For end-to-end models, please select `None` \"\n",
    "        \"in the second and third column to continue. For the two step pipeline, please select `None` in the first column to continue.\"\n",
    "    )\n",
    "\n",
    "if (spectrogram_generator and audio_generator is None) or (audio_generator and spectrogram_generator is None):\n",
    "    raise ValueError(\"In order to continue with the two step pipeline, both the spectrogram generator and the audio generator must be chosen, but one was `None`\")\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model checkpoints\n",
    "\n",
    "Next we load the pretrained model provided by NeMo. All NeMo models have two functions to help with this\n",
    "\n",
    "- list_available_models(): This function will return a list of all pretrained checkpoints for that model\n",
    "- from_pretrained(): This function will download the pretrained checkpoint, load it, and return an instance of the model\n",
    "\n",
    "Below we will use `from_pretrained` to load the chosen models from above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false,
    "tags": []
   },
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf, open_dict\n",
    "import torch\n",
    "from nemo.collections.tts.models.base import SpectrogramGenerator, Vocoder, TextToWaveform\n",
    "\n",
    "\n",
    "def load_spectrogram_model():\n",
    "    override_conf = None\n",
    "    if spectrogram_generator == \"tacotron2\":\n",
    "        from nemo.collections.tts.models import Tacotron2Model\n",
    "        pretrained_model = \"tts_en_tacotron2\"\n",
    "    elif spectrogram_generator == \"glow_tts\":\n",
    "        from nemo.collections.tts.models import GlowTTSModel\n",
    "        pretrained_model = \"tts_en_glowtts\"\n",
    "        import wget\n",
    "        from pathlib import Path\n",
    "        if not Path(\"cmudict-0.7b\").exists():\n",
    "            filename = wget.download(\"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b\")\n",
    "            filename = str(Path(filename).resolve())\n",
    "        else:\n",
    "            filename = str(Path(\"cmudict-0.7b\").resolve())\n",
    "        conf = SpectrogramGenerator.from_pretrained(pretrained_model, return_config=True)\n",
    "        if \"params\" in conf.parser:\n",
    "            conf.parser.params.cmu_dict_path = filename\n",
    "        else:\n",
    "            conf.parser.cmu_dict_path = filename\n",
    "        override_conf = conf\n",
    "    elif spectrogram_generator == \"talknet\":\n",
    "        from nemo.collections.tts.models import TalkNetSpectModel\n",
    "        pretrained_model = \"tts_en_talknet\" \n",
    "    elif spectrogram_generator == \"fastpitch\":\n",
    "        from nemo.collections.tts.models import FastPitchModel\n",
    "        pretrained_model = \"tts_en_fastpitch\"\n",
    "    elif spectrogram_generator == \"fastspeech2\":\n",
    "        from nemo.collections.tts.models import FastSpeech2Model\n",
    "        pretrained_model = \"tts_en_fastspeech2\"\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "    model = SpectrogramGenerator.from_pretrained(pretrained_model, override_config_path=override_conf)\n",
    "    \n",
    "    if spectrogram_generator == \"talknet\":\n",
    "        from nemo.collections.tts.models import TalkNetPitchModel\n",
    "        pitch_model = TalkNetPitchModel.from_pretrained(pretrained_model, override_config_path=override_conf)\n",
    "        from nemo.collections.tts.models import TalkNetDursModel\n",
    "        durs_model = TalkNetDursModel.from_pretrained(pretrained_model, override_config_path=override_conf)\n",
    "        model.add_module('_pitch_model', pitch_model)\n",
    "        model.add_module('_durs_model', durs_model)\n",
    "    \n",
    "    return model\n",
    "\n",
    "\n",
    "def load_vocoder_model():\n",
    "    RequestPseudoInverse = False\n",
    "    TwoStagesModel = False\n",
    "    strict=True\n",
    "    \n",
    "    if audio_generator == \"waveglow\":\n",
    "        from nemo.collections.tts.models import WaveGlowModel\n",
    "        pretrained_model = \"tts_waveglow\"\n",
    "        strict=False\n",
    "    elif audio_generator == \"squeezewave\":\n",
    "        from nemo.collections.tts.models import SqueezeWaveModel\n",
    "        pretrained_model = \"tts_squeezewave\"\n",
    "    elif audio_generator == \"uniglow\":\n",
    "        from nemo.collections.tts.models import UniGlowModel\n",
    "        pretrained_model = \"tts_uniglow\"\n",
    "    elif audio_generator == \"melgan\":\n",
    "        from nemo.collections.tts.models import MelGanModel\n",
    "        pretrained_model = \"tts_melgan\"\n",
    "    elif audio_generator == \"hifigan\":\n",
    "        from nemo.collections.tts.models import HifiGanModel\n",
    "        pretrained_model = \"tts_hifigan\"\n",
    "    elif audio_generator == \"griffin-lim\":\n",
    "        from nemo.collections.tts.models import TwoStagesModel\n",
    "        cfg = {'linvocoder':  {'_target_': 'nemo.collections.tts.models.two_stages.GriffinLimModel',\n",
    "                             'cfg': {'n_iters': 64, 'n_fft': 1024, 'l_hop': 256}},\n",
    "               'mel2spec': {'_target_': 'nemo.collections.tts.models.two_stages.MelPsuedoInverseModel',\n",
    "                           'cfg': {'sampling_rate': 22050, 'n_fft': 1024, \n",
    "                                   'mel_fmin': 0, 'mel_fmax': 8000, 'mel_freq': 80}}}\n",
    "        model = TwoStagesModel(cfg)            \n",
    "        TwoStagesModel = True\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "    if not TwoStagesModel:\n",
    "        model = Vocoder.from_pretrained(pretrained_model, strict=strict)\n",
    "    return model\n",
    "\n",
    "def load_e2e_model():\n",
    "    if e2e_model == \"fastpitch_hifigan\":\n",
    "        from nemo.collections.tts.models import FastPitchHifiGanE2EModel\n",
    "        pretrained_model = \"tts_en_e2e_fastpitchhifigan\"\n",
    "    elif e2e_model == \"fastspeech2_hifigan\":\n",
    "        from nemo.collections.tts.models import FastSpeech2HifiGanE2EModel\n",
    "        pretrained_model = \"tts_en_e2e_fastspeech2hifigan\"\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "    model = TextToWaveform.from_pretrained(pretrained_model)\n",
    "    return model\n",
    "\n",
    "emodel = None\n",
    "spec_gen = None\n",
    "vocoder = None\n",
    "if e2e_model:\n",
    "    emodel = load_e2e_model().eval().cuda()\n",
    "else:\n",
    "    spec_gen = load_spectrogram_model().eval().cuda()\n",
    "    vocoder = load_vocoder_model().eval().cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "Now that we have downloaded the model checkpoints and loaded them into memory. Let's define a short infer helper function that takes a string, and our models to produce speech.\n",
    "\n",
    "Notice that the NeMo TTS model interface is fairly simple and standardized across all models.\n",
    "\n",
    "End-to-end models have two helper functions:\n",
    "- parse(): Accepts raw python strings and returns a torch.tensor that represents tokenized text\n",
    "- convert_text_to_waveform(): Accepts a batch of tokenized text and returns a torch.tensor that represents a batch of raw audio\n",
    "\n",
    "Mel Spectrogram generators have two helper functions:\n",
    "\n",
    "- parse(): Accepts raw python strings and returns a torch.tensor that represents tokenized text\n",
    "- generate_spectrogram(): Accepts a batch of tokenized text and returns a torch.tensor that represents a batch of spectrograms\n",
    "\n",
    "Vocoder have just one helper function:\n",
    "\n",
    "- convert_spectrogram_to_audio(): Accepts a batch of spectrograms and returns a torch.tensor that represents a batch of raw audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def infer(emodel, spec_gen_model, vocder_model, str_input):\n",
    "    parser_model = emodel or spec_gen_model\n",
    "    with torch.no_grad():\n",
    "        parsed = parser_model.parse(str_input)\n",
    "        if emodel is None:\n",
    "            spectrogram = spec_gen.generate_spectrogram(tokens=parsed)\n",
    "            audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)\n",
    "        else:\n",
    "            spectrogram = None\n",
    "            audio = emodel.convert_text_to_waveform(tokens=parsed)[0]\n",
    "    if spectrogram is not None:\n",
    "        if isinstance(spectrogram, torch.Tensor):\n",
    "            spectrogram = spectrogram.to('cpu').numpy()\n",
    "        if len(spectrogram.shape) == 3:\n",
    "            spectrogram = spectrogram[0]\n",
    "    if isinstance(audio, torch.Tensor):\n",
    "        audio = audio.to('cpu').numpy()\n",
    "    return spectrogram, audio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that everything is set up, let's give an input that we want our models to speak"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_to_generate = input(\"Input what you want the model to say: \")\n",
    "spec, audio = infer(emodel, spec_gen, vocoder, text_to_generate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Results\n",
    "\n",
    "After our model generates the audio, let's go ahead and play it. We can also visualize the spectrogram that was produced from the first stage model if a spectrogram generator was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import IPython.display as ipd\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from matplotlib.pyplot import imshow\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "ipd.Audio(audio, rate=22050)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "if spec is not None:\n",
    "    imshow(spec, origin=\"lower\")\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}